{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "29q48ihMZ1CA"
      },
      "outputs": [],
      "source": [
        "# Copyright 2022 Google LLC\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "from colabtools import adhoc_import\n",
        "from typing import Any, Sequence\n",
        "import ml_collections\n",
        "import numpy as np\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import functools\n",
        "import itertools\n",
        "from scipy.stats import mode\n",
        "from collections import defaultdict\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.pyplot import cm\n",
        "import time\n",
        "import pickle\n",
        "from sklearn.cluster import KMeans, MiniBatchKMeans\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.cluster import AgglomerativeClustering\n",
        "from sklearn.metrics import adjusted_mutual_info_score\n",
        "from sklearn.mixture import GaussianMixture\n",
        "\n",
        "from clu import preprocess_spec\n",
        "from scipy.special import comb\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wjfsP-qeZ5Vl"
      },
      "outputs": [],
      "source": [
        "def get_learning_rate(step: int,\n",
        "                      *,\n",
        "                      base_learning_rate: float,\n",
        "                      steps_per_epoch: int,\n",
        "                      num_epochs: int,\n",
        "                      warmup_epochs: int = 5):\n",
        "  \"\"\"Cosine learning rate schedule.\"\"\"\n",
        "  logging.info(\n",
        "      \"get_learning_rate(step=%s, base_learning_rate=%s, steps_per_epoch=%s, num_epochs=%s\",\n",
        "      step, base_learning_rate, steps_per_epoch, num_epochs)\n",
        "  if steps_per_epoch \u003c= 0:\n",
        "    raise ValueError(f\"steps_per_epoch should be a positive integer but was \"\n",
        "                     f\"{steps_per_epoch}.\")\n",
        "  if warmup_epochs \u003e= num_epochs:\n",
        "    raise ValueError(f\"warmup_epochs should be smaller than num_epochs. \"\n",
        "                     f\"Currently warmup_epochs is {warmup_epochs}, \"\n",
        "                     f\"and num_epochs is {num_epochs}.\")\n",
        "  epoch = step / steps_per_epoch\n",
        "  lr = cosine_decay(base_learning_rate, epoch - warmup_epochs,\n",
        "                    num_epochs - warmup_epochs)\n",
        "  warmup = jnp.minimum(1., epoch / warmup_epochs)\n",
        "  return lr * warmup\n",
        "\n",
        "def predict(model, state, batch):\n",
        "  \"\"\"Get intermediate representations from a model.\"\"\"\n",
        "  variables = {\n",
        "      \"params\": state.ema_params,\n",
        "      \"batch_stats\": state.batch_stats\n",
        "  }\n",
        "  _, state = model.apply(variables, batch['image'], capture_intermediates=True, mutable=[\"intermediates\"], train=False)\n",
        "  intermediates = state['intermediates']#['stage4']['__call__'][0]\n",
        "  return intermediates\n",
        "\n",
        "def compute_purity(clusters, classes):\n",
        "  \"\"\"Compute purity of the cluster.\"\"\"\n",
        "  n_cluster_points = 0\n",
        "  for cluster_idx in set(clusters):\n",
        "    instance_idx = np.where(clusters == cluster_idx)[0]\n",
        "    subclass_labels = classes[instance_idx]\n",
        "    mode_stats = mode(subclass_labels)\n",
        "    n_cluster_points += mode_stats[1][0]\n",
        "  purity = n_cluster_points / len(clusters)\n",
        "  return purity\n",
        "\n",
        "def show_images_horizontally(images, labels, super_label=None, max_samples=20, show_labels=False):\n",
        "  \"\"\"Display images in a row using matplotlib.\"\"\"\n",
        "  if max_samples is not None:\n",
        "    number_of_files = min(images.shape[0], max_samples)\n",
        "    fig = plt.figure(figsize=(1.5 * number_of_files, 1.5))\n",
        "    n_rows = 1\n",
        "    n_cols = number_of_files\n",
        "  else:\n",
        "    number_of_files = images.shape[0]\n",
        "    n_cols = 10\n",
        "    n_rows = int(np.ceil(number_of_files/n_cols))\n",
        "    fig = plt.figure(figsize=(1.5 * n_cols, 1.5 * n_rows))\n",
        "                          \n",
        "  for i in range(number_of_files):\n",
        "    axes = fig.add_subplot(n_rows, n_cols, i+1)\n",
        "    axes.imshow(images[i])\n",
        "    if show_labels:\n",
        "      if isinstance(labels, list):\n",
        "        axes.set_title(labels[i])\n",
        "      else:\n",
        "        axes.set_title(labels)\n",
        "  if super_label:\n",
        "    plt.title(super_label)\n",
        "  plt.show()\n",
        "\n",
        "def visualize_all_clusters(all_images, cluster_labels, all_subclass_labels, hier, gender_labels, show_labels=False):\n",
        "  \"\"\"Visualize images from each cluster.\"\"\"\n",
        "  for label in list(set(cluster_labels)):\n",
        "    instance_idx = np.where(cluster_labels == label)[0]\n",
        "    subclass_labels = all_subclass_labels[instance_idx]\n",
        "    mode_stats = mode(subclass_labels)\n",
        "    cluster_images = all_images[instance_idx]\n",
        "    cluster_label = mode_stats[0][0]\n",
        "    if hier:\n",
        "      cluster_label_str = hier.LEAF_NUM_TO_NAME[cluster_label].split(',')[0]\n",
        "      subclass_labels = [hier.LEAF_NUM_TO_NAME[s].split(',')[0] for s in subclass_labels]\n",
        "      show_images_horizontally(cluster_images, subclass_labels, super_label=cluster_label_str, max_samples=10, show_labels=show_labels)\n",
        "    elif gender_labels:\n",
        "      male_pc = sum(gender_labels[instance_idx]) * 100.0 / len(instance_idx)\n",
        "      cluster_label_str = f'Male ratio = {male_pc}%'\n",
        "      subclass_labels = ['' for s in subclass_labels]\n",
        "      show_images_horizontally(cluster_images, subclass_labels, super_label=cluster_label_str, max_samples=10, show_labels=show_labels)\n",
        "    else:\n",
        "      show_images_horizontally(cluster_images, subclass_labels, super_label='', max_samples=10, show_labels=show_labels)\n",
        "\n",
        "config = get_config()\n",
        "learning_rate_fn = functools.partial(\n",
        "      get_learning_rate,\n",
        "      base_learning_rate=0.1,\n",
        "      steps_per_epoch=40,\n",
        "      num_epochs=config.num_epochs,\n",
        "      warmup_epochs=config.warmup_epochs)\n",
        "\n",
        "\n",
        "def evaluate_purity(eval_dataset, model_dir, ckpt_number, n_classes, overcluster_factors, n_subclasses):\n",
        "  \"\"\"Given a model and a dataset, cluster the second-to-last layer representations and compute average purity.\"\"\"\n",
        "  checkpoint_path = os.path.join(model_dir, f'checkpoints-0/ckpt-{ckpt_number}.flax')\n",
        "  model, state = create_train_state(config, jax.random.PRNGKey(0), input_shape=(8, 224, 224, 3), \n",
        "                                    num_classes=n_classes, learning_rate_fn=learning_rate_fn)\n",
        "  state = checkpoints.restore_checkpoint(checkpoint_path, state)\n",
        "  print(\"Ckpt number\", ckpt_number, \"Ckpt step:\", state.step)\n",
        "\n",
        "  result_dict = {}\n",
        "  all_intermediates = []\n",
        "  all_subclass_labels = []\n",
        "  all_images = []\n",
        "  for step, batch in enumerate(eval_ds):\n",
        "    if step % 50 == 0:\n",
        "      print(step)\n",
        "    intermediates = predict(model, state, batch)\n",
        "    labels = batch['label'].numpy()\n",
        "    bs = labels.shape[0]\n",
        "    all_subclass_labels.append(labels)\n",
        "    all_images.append(batch['image'].numpy())\n",
        "    all_intermediates.append(np.mean(intermediates['stage4']['__call__'][0], axis=(1,2)).reshape(bs, -1))\n",
        "\n",
        "  all_intermediates = np.vstack(all_intermediates)\n",
        "  all_subclass_labels = np.hstack(all_subclass_labels)\n",
        "  all_images = np.vstack(all_images)\n",
        "\n",
        "  for overcluster_factor in overcluster_factors:\n",
        "    clf = AgglomerativeClustering(n_clusters=n_subclasses*overcluster_factor,\n",
        "                                          linkage='ward').fit(all_intermediates)\n",
        "    all_clf_labels = clf.labels_\n",
        "    purity = compute_purity(clf.labels_, all_subclass_labels)\n",
        "    result_dict[overcluster_factor] = purity\n",
        "  return result_dict, all_clf_labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FsabKV0qna2f"
      },
      "source": [
        "# Load OOD dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C2e3b75InaCF"
      },
      "outputs": [],
      "source": [
        "DATASET = 'oxford_iiit_pet'\n",
        "SPLIT = tfds.Split.TRAIN\n",
        "dataset_builder = tfds.builder(DATASET, try_gcs=True)\n",
        "eval_preprocess = preprocess_spec.PreprocessFn([\n",
        "    RescaleValues(),\n",
        "    ResizeSmall(256),\n",
        "    CentralCrop(224),\n",
        "    GeneralPreprocessOp(),\n",
        "    ], only_jax_types=True)\n",
        "dataset_options = tf.data.Options()\n",
        "dataset_options.experimental_optimization.map_parallelization = True\n",
        "dataset_options.experimental_threading.private_threadpool_size = 48\n",
        "dataset_options.experimental_threading.max_intra_op_parallelism = 1\n",
        "read_config = tfds.ReadConfig(shuffle_seed=None, options=dataset_options)\n",
        "eval_ds = dataset_builder.as_dataset(\n",
        "    split=SPLIT,\n",
        "    shuffle_files=False,\n",
        "    read_config=read_config,\n",
        "    decoders=None)\n",
        "batch_size = 64\n",
        "eval_ds = eval_ds.cache()\n",
        "eval_ds = eval_ds.map(eval_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "eval_ds = eval_ds.batch(batch_size, drop_remainder=False)\n",
        "eval_ds = eval_ds.prefetch(tf.data.experimental.AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eh2Zs6fynUhs"
      },
      "source": [
        "# Evaluate purity on OOD data for a single model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jGco4VbfZ5vI"
      },
      "outputs": [],
      "source": [
        "model_dir = os.path.join(BASE_DIR, 'breeds/entity13_400_epochs_ema_0.99_bn_0.99/')\n",
        "ckpt_number = 161\n",
        "N_CLASSES = 13\n",
        "model_dir = os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/')\n",
        "ckpt_number = 173\n",
        "N_CLASSES = 17\n",
        "# model_dir = os.path.join(BASE_DIR, 'breeds/entity13_4_subclasses_shuffle_400_epochs_ema_0.99_bn_0.99/')\n",
        "# ckpt_number = 129\n",
        "# N_CLASSES = 13\n",
        "model_dir = os.path.join(BASE_DIR, 'breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/')\n",
        "ckpt_number = 257\n",
        "N_CLASSES = 26\n",
        "model_dir = os.path.join(BASE_DIR, 'breeds/imagenet_ema_0.99_bn_0.99/')\n",
        "ckpt_number = 8\n",
        "N_CLASSES = 1000\n",
        "overcluster_factors = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
        "n_subclasses = 5\n",
        "\n",
        "result_dict, all_clf_labels = evaluate_purity(eval_ds, model_dir, ckpt_number, N_CLASSES, overcluster_factors, n_subclasses)\n",
        "train_subset = model_dir.split('/')[-2].split('_')[0]\n",
        "print(result_dict)\n",
        "plt.plot(result_dict.keys(), result_dict.values(), marker='o')\n",
        "plt.xlabel(\"Overclustering factor\")\n",
        "plt.ylabel(\"Purity\")\n",
        "plt.title(f\"{train_subset} -\u003e {DATASET}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zreszJSinQPu"
      },
      "source": [
        "# Evaluate purity on OOD data for multiple models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "scp4zw0yVjP9"
      },
      "outputs": [],
      "source": [
        "model_metadata = [(os.path.join(BASE_DIR, 'breeds/entity13_400_epochs_ema_0.99_bn_0.99/'), 161, 13),\n",
        "                  (os.path.join(BASE_DIR, 'breeds/living17_400_epochs_ema_0.99_bn_0.99/'), 173, 17),\n",
        "                  (os.path.join(BASE_DIR, 'breeds/nonliving26_400_epochs_ema_0.99_bn_0.99/'), 257, 26),\n",
        "                  (os.path.join(BASE_DIR, 'breeds/imagenet_ema_0.99_bn_0.99/'), 8, 1000)\n",
        "                  ]\n",
        "\n",
        "fig = plt.figure()\n",
        "for model_dir, ckpt_number, N_CLASSES in model_metadata:\n",
        "  result_dict, _ = evaluate_purity(eval_ds, model_dir, ckpt_number, N_CLASSES, overcluster_factors, n_subclasses)\n",
        "  print(result_dict)\n",
        "  leg = model_dir.split('/')[-2].split('_')[0]\n",
        "  plt.plot(result_dict.keys(), result_dict.values(), marker='o', label=leg)\n",
        "plt.xlabel(\"Overclustering factor\")\n",
        "plt.ylabel(\"Purity\")\n",
        "plt.legend()\n",
        "plt.title(f\"{DATASET}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MTKEWnzdnKQK"
      },
      "source": [
        "# Visualize sample images from some clusters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tQm0puaqPD6r"
      },
      "outputs": [],
      "source": [
        "for clf_label in range(15):\n",
        "  clf_idx = np.where(np.array(all_clf_labels) == clf_label)[0]\n",
        "  visualize_all_clusters(all_images[clf_idx], all_clf_labels[clf_idx], all_subclass_labels[clf_idx], None, None)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1GdimBG3IZBdYGTG_mlz822n1CHb5IG2Q",
          "timestamp": 1664583179924
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
